In this tutorial, we will be looking at how we can fine-tune a Vision Transformer model for multi-label image classification.
In multi-label classification, each image in our dataset has one or more class labels, unlike multi-class classification, where each image has exactly one label.
For this tutorial, we'll fine-tune a Swin Transformer, specifically swin_s3_base_224, using the Hugging Face timm library to obtain our pre-trained model.
For the dataset, we'll use the Pascal VOC 2007 dataset, which includes annotations for both multi-label classification and object detection.
In this tutorial, we'll also use Hugging Face accelerate to power our training loops, and for calculating metrics we'll use Hugging Face evaluate.
Using accelerate lets us write loops that also work automatically in distributed configurations, with support for mixed precision, FSDP, DeepSpeed, etc. In this notebook, though, we'll keep things pretty simple.
!pip install -Uq transformers datasets timm accelerate evaluate
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as T
from pathlib import Path
from PIL import Image
import datasets
from transformers.optimization import get_cosine_schedule_with_warmup
from timm import list_models, create_model
from accelerate import Accelerator, notebook_launcher
import evaluate
We'll download the Pascal VOC 2007 dataset from the Hugging Face Hub using the datasets library.
dataset = datasets.load_dataset('fuliucansheng/pascal_voc','voc2007_main')
dataset
The dataset contains several features, including the image (in PIL.Image format) and its classes. For this tutorial, we'll only need the image and classes features from the dataset.
These are all the unique labels available in the dataset, 20 in total. Each image might have more than one label associated with it.
Since the classes in the dataset are given as integers, we'll create two mappings, label2id and id2label, to convert the labels to their IDs and vice versa. This will make the labels easier to interpret during visualization.
class_names = [
"Aeroplane","Bicycle","Bird","Boat","Bottle",
"Bus","Car","Cat","Chair","Cow","Diningtable",
"Dog","Horse","Motorbike","Person",
"Potted plant","Sheep","Sofa","Train","Tv/monitor"
]
label2id = {c:idx for idx,c in enumerate(class_names)}
id2label = {idx:c for idx,c in enumerate(class_names)}
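As a quick sanity check, the two mappings let us go back and forth between names and IDs (this snippet just repeats the definitions above for illustration):

```python
class_names = [
    "Aeroplane","Bicycle","Bird","Boat","Bottle",
    "Bus","Car","Cat","Chair","Cow","Diningtable",
    "Dog","Horse","Motorbike","Person",
    "Potted plant","Sheep","Sofa","Train","Tv/monitor"
]
label2id = {c: idx for idx, c in enumerate(class_names)}
id2label = {idx: c for idx, c in enumerate(class_names)}

print(label2id["Cat"])                 # 7
print(id2label[14])                    # Person
print([id2label[i] for i in [3, 5]])   # ['Boat', 'Bus']
```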
For any dataset we use with the datasets library, we can shuffle it using shuffle() and select samples using the select() method.
As you'll notice, some of the images have more than one label.
def show_samples(ds,rows,cols):
    samples = ds.shuffle().select(np.arange(rows*cols)) # selecting random images
    fig = plt.figure(figsize=(cols*4,rows*4))
    # plotting
    for i in range(rows*cols):
        img = samples[i]['image']
        labels = samples[i]['classes']
        # getting string labels and combining them with a comma
        labels = ','.join([id2label[lb] for lb in labels])
        fig.add_subplot(rows,cols,i+1)
        plt.imshow(img)
        plt.title(labels)
        plt.axis('off')
show_samples(dataset['train'],rows=5,cols=5)
When it comes to image datasets, preprocessing involves multiple steps. Let's discuss them in detail. To apply these image and label transformations, we'll define train_transforms and valid_transforms functions that preprocess a batch of samples during training.
This includes transforms such as resizing all images to have the same dimensions, normalizing, and scaling the pixel values to a uniform range. We can also add augmentations to our images like random flips, rotations, perspectives, etc.
For our transforms and augmentations, we'll be using torchvision.
Note: we apply random augmentations such as flips, rotations, etc. to the training dataset only. Hence we'll create two different transforms: train_tfms for training and valid_tfms for validation and testing.
The transforms are as follows:
swin_s3_base_224 indicates that the model's input size should be 224x224, so we'll resize our images accordingly. Since each sample can have multiple labels, we'll use one-hot encoding, which transforms our list of labels into a vector of 0s and 1s. The length of the vector equals the number of labels; the value at each label's index is 1 and the rest are 0s.
Example:
label: [3,5], num_labels = 10
one-hot encoded label: [0 0 0 1 0 1 0 0 0 0], at index 3 and 5 the value will be 1
To do this in PyTorch, we'll be using torch.nn.functional.one_hot which works in the following manner:
When we pass a sample/batch to the train_transforms or valid_transforms function, the classes arrive in the form [[3,5]], a list of per-sample label lists. First we convert this list of lists into a tensor, then one-hot encode it.
Example:
>>> sample_batch = [[2,14]] # batch with 1 sample
>>> labels = torch.tensor(sample_batch)
>>> labels
tensor([[ 2, 14]])
>>> labels = nn.functional.one_hot(labels, num_classes=20) # provide total classes
>>> labels
tensor([[[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]]])
>>> labels = labels.sum(dim=1) # sum along dim=1 to get a flattened one-hot encoding
>>> labels
tensor([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]]) # indices 2 and 14 have value 1
When we apply this one-hot encoding, we are essentially converting this problem into a binary classification problem for each label.
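To make this concrete, here is a minimal sketch (with made-up logits) showing that nn.BCEWithLogitsLoss treats each of the 20 labels as an independent binary problem, i.e. it is equivalent to applying a sigmoid followed by per-label binary cross-entropy:

```python
import torch
import torch.nn as nn

# One sample with labels {2, 14}, one-hot encoded over 20 classes
target = torch.zeros(1, 20)
target[0, [2, 14]] = 1.0

logits = torch.randn(1, 20)  # made-up model outputs

loss = nn.BCEWithLogitsLoss()(logits, target)

# Manual equivalent: sigmoid + binary cross-entropy per label, averaged
probs = torch.sigmoid(logits)
manual = -(target * probs.log() + (1 - target) * (1 - probs).log()).mean()
print(torch.allclose(loss, manual, atol=1e-5))  # True
```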
img_size = (224,224)

train_tfms = T.Compose([
    T.Resize(img_size),
    T.RandomHorizontalFlip(),
    T.RandomRotation(30),
    T.CenterCrop(img_size),
    T.ToTensor(),
    T.Normalize(
        mean = (0.5,0.5,0.5),
        std = (0.5,0.5,0.5)
    )
])

valid_tfms = T.Compose([
    T.Resize(img_size),
    T.ToTensor(),
    T.Normalize(
        mean = (0.5,0.5,0.5),
        std = (0.5,0.5,0.5)
    )
])
def train_transforms(batch):
    # Convert all images to RGB format
    if isinstance(batch['image'], list):
        # Batch processing
        batch['image'] = [x.convert('RGB') for x in batch['image']]
        inputs = [train_tfms(x) for x in batch['image']]
        batch['pixel_values'] = torch.stack(inputs) # Stack tensor outputs
    else:
        # Single sample processing
        batch['image'] = batch['image'].convert('RGB')
        batch['pixel_values'] = train_tfms(batch['image'])
    # One-hot encode the multilabels
    all_labels = [torch.tensor(labels) for labels in batch['classes']]
    # Create one-hot encoding for each image's classes
    one_hot_labels = [nn.functional.one_hot(label, num_classes=20).sum(dim=0) for label in all_labels]
    # Stack them into a batch
    batch['labels'] = torch.stack(one_hot_labels)
    return batch
def valid_transforms(batch):
    # Convert all images to RGB format
    if isinstance(batch['image'], list):
        # Batch processing
        batch['image'] = [x.convert('RGB') for x in batch['image']]
        inputs = [valid_tfms(x) for x in batch['image']] # validation transforms: no augmentation
        batch['pixel_values'] = torch.stack(inputs) # Stack tensor outputs
    else:
        # Single sample processing
        batch['image'] = batch['image'].convert('RGB')
        batch['pixel_values'] = valid_tfms(batch['image'])
    # One-hot encode the multilabels
    all_labels = [torch.tensor(labels) for labels in batch['classes']]
    # Create one-hot encoding for each image's classes
    one_hot_labels = [nn.functional.one_hot(label, num_classes=20).sum(dim=0) for label in all_labels]
    # Stack them into a batch
    batch['labels'] = torch.stack(one_hot_labels)
    return batch
We'll pair the preprocessing functions with our datasets using the with_transform method.
train_dataset = dataset['train'].with_transform(train_transforms)
valid_dataset = dataset['validation'].with_transform(valid_transforms)
test_dataset = dataset['test'].with_transform(valid_transforms)
len(train_dataset), len(valid_dataset), len(test_dataset)
Collation is the step of batching our data in the correct format. For pixel_values, the model's input shape should be (batch, channels, height, width), and for our one-hot encoded labels, the shape should be (batch, num_labels).
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([x['labels'] for x in batch]).float()
    }
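As a quick check with dummy samples (assuming the 224x224 input size and 20 labels used above), collate_fn produces the expected shapes:

```python
import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([x['labels'] for x in batch]).float()
    }

# Two fake preprocessed samples with the shapes our transforms produce
batch = [
    {'pixel_values': torch.randn(3, 224, 224), 'labels': torch.zeros(20)}
    for _ in range(2)
]
out = collate_fn(batch)
print(out['pixel_values'].shape)  # torch.Size([2, 3, 224, 224])
print(out['labels'].shape)        # torch.Size([2, 20])
```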
A handy function to calculate the number of trainable parameters in our model:
def param_count(model):
    params = [(p.numel(),p.requires_grad) for p in model.parameters()]
    trainable = sum([count for count,trainable in params if trainable])
    total = sum([count for count,_ in params])
    frac = (trainable / total) * 100
    return total, trainable, frac
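For example, on a tiny stand-in model (a single nn.Linear layer rather than the Swin model, just to keep the numbers easy to verify):

```python
import torch.nn as nn

def param_count(model):
    params = [(p.numel(), p.requires_grad) for p in model.parameters()]
    trainable = sum(count for count, trainable in params if trainable)
    total = sum(count for count, _ in params)
    frac = (trainable / total) * 100
    return total, trainable, frac

# 10*4 weights + 4 biases = 44 parameters
model = nn.Linear(10, 4)
total, trainable, frac = param_count(model)
print(total, trainable, frac)  # 44 44 100.0

# Freeze the weight matrix; only the 4 bias terms remain trainable
model.weight.requires_grad_(False)
total, trainable, frac = param_count(model)
print(total, trainable)  # 44 4
```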
We'll create our pre-trained model from the timm library using timm.create_model, passing the model name, pretrained, and num_classes as arguments.
To list the available models, we can use timm.list_models. We can pass a string pattern such as *swin* or *vit*, which will match all available model names containing the pattern. You can also pass pretrained=True to only list models with available pretrained weights. Example: timm.list_models("*swin*",pretrained=True)
Let's now put together everything we need for training:
- DataLoaders: defined with torch.utils.data.DataLoader.
- Loss function: nn.BCEWithLogitsLoss(), which takes our predictions and targets of shape (batch, num_labels).
- Model: created from the timm library.
- Optimizer, scheduler: get_cosine_schedule_with_warmup from transformers.optimization. In this scheduler, the learning rate increases gradually until num_warmup_steps and decays for the remaining steps with cosine annealing.
- Metrics: from the evaluate library. We will be using the roc_auc metric for multilabel with micro averaging, which calculates the metric globally. For more explanation and references about the metric, check this evaluate space.

With accelerate, the workflow is:
- Create an accelerator instance with Accelerator() along with any further configuration kwargs.
- Pass the model, optimizer, scheduler, and dataloaders to the accelerator.prepare method.
- To gather predictions and labels from all devices for metrics, use accelerator.gather_for_metrics.
- To print only once from the main process, use accelerator.print.
- Since we'll be running from our Jupyter notebook, we'll use notebook_launcher, which will call our train function containing all of our logic and the accelerator instance.
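To build intuition for the warmup-then-cosine shape, here is a sketch of the standard per-step learning-rate multiplier (the actual schedule comes from transformers' get_cosine_schedule_with_warmup; this is just the textbook formula, not the library's exact code):

```python
import math

def lr_multiplier(step, num_warmup_steps, num_training_steps):
    # Linear warmup from 0 to 1 over the warmup steps...
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    # ...then cosine decay from 1 down to 0 over the remaining steps
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

warmup, total = 10, 100
print(lr_multiplier(0, warmup, total))    # 0.0   (start of warmup)
print(lr_multiplier(5, warmup, total))    # 0.5   (halfway through warmup)
print(lr_multiplier(10, warmup, total))   # 1.0   (peak learning rate)
print(lr_multiplier(100, warmup, total))  # 0.0   (fully decayed)
```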
For further information and details on how to use accelerate, check out the docs and this handy HF space.
def train(model_name,batch_size=16,epochs=1,lr=2e-4):
    """
    Contains all of our training logic.
    1. define Accelerator instance
    2. define dataloaders, model, optimizer, loss function, scheduler
    3. write training, validation and testing loops.
    """
    accelerator = Accelerator() # create instance
    # define dataloaders
    train_dl = torch.utils.data.DataLoader(
        train_dataset,
        batch_size = batch_size, # the batch_size will be per-device
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn
    )
    valid_dl = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size = batch_size*2,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn
    )
    test_dl = torch.utils.data.DataLoader(
        test_dataset,
        batch_size = batch_size*2,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn
    )
    # timm model
    model = create_model(
        model_name,
        pretrained = True,
        num_classes = 20
    ).to(accelerator.device) # device placement: accelerator.device
    total, trainable, frac = param_count(model)
    accelerator.print(f"{total = :,} | {trainable = :,} | {frac:.2f}%")
    # loss, optimizer, scheduler
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(),lr=lr,weight_decay=0.02)
    num_training_steps = len(train_dl) * epochs # schedule over all epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps = int(0.1 * num_training_steps),
        num_training_steps = num_training_steps
    )
    model, optimizer, scheduler, train_dl, valid_dl, test_dl = accelerator.prepare(
        model, optimizer, scheduler, train_dl, valid_dl, test_dl
    )
    # loop over the number of epochs
    for epoch in range(1,epochs+1):
        model.train() # set model to train
        train_metric = evaluate.load('roc_auc','multilabel') # load metric
        running_loss = 0.
        for batch in train_dl:
            logits = model(batch['pixel_values'])
            loss = loss_fn(logits,batch['labels'])
            accelerator.backward(loss) # backpropagation
            optimizer.step() # update weights
            scheduler.step() # update LR
            optimizer.zero_grad() # set grad values to zero
            running_loss += loss.item() # keep track of loss
            # gather across devices for metrics
            logits, labels = accelerator.gather_for_metrics(
                (logits, batch['labels'])
            )
            train_metric.add_batch(references=labels, prediction_scores=logits)
        # loss and metric over 1 epoch
        train_loss = running_loss / len(train_dl)
        train_roc_auc = train_metric.compute(average='micro')['roc_auc']
        accelerator.print(f"\n{epoch = }")
        accelerator.print(f"{train_loss = :.3f} | {train_roc_auc = :.3f}")
        # validation loop
        model.eval() # set model for evaluation
        running_loss = 0.
        valid_metric = evaluate.load('roc_auc','multilabel')
        for batch in valid_dl:
            with torch.no_grad():
                logits = model(batch['pixel_values'])
                loss = loss_fn(logits, batch['labels'])
            running_loss += loss.item()
            logits, labels = accelerator.gather_for_metrics(
                (logits, batch['labels'])
            )
            valid_metric.add_batch(references=labels, prediction_scores=logits)
        valid_loss = running_loss / len(valid_dl)
        valid_roc_auc = valid_metric.compute(average='micro')['roc_auc']
        accelerator.print(f"{valid_loss = :.3f} | {valid_roc_auc = :.3f}")
    # save model
    accelerator.save_model(model, f'./{model_name}-pascal')
    # testing loop after all epochs are over
    test_metric = evaluate.load('roc_auc','multilabel')
    for batch in test_dl:
        with torch.no_grad():
            logits = model(batch['pixel_values'])
        logits, labels = accelerator.gather_for_metrics(
            (logits, batch['labels'])
        )
        test_metric.add_batch(references=labels, prediction_scores=logits)
    test_roc_auc = test_metric.compute(average='micro')['roc_auc']
    accelerator.print(f"\n\nTEST AUROC: {test_roc_auc:.3f}")
With notebook_launcher, we start the training procedure by calling our train function with the args (model_name, batch_size, epochs, lr) defined above, and num_processes equal to the number of GPUs.
model_name = 'swin_s3_base_224'
notebook_launcher(train, (model_name,8,5,5e-5), num_processes = 2)
Since we saved the model with accelerator.save_model, it is stored in safetensors format. We'll initialize the model architecture and then load the saved weights into it.
# initialize the model
model = create_model(
    model_name,
    num_classes=20
)
from safetensors.torch import load_model
load_model(model,f'./{model_name}-pascal/model.safetensors')
def show_predictions(rows=2,cols=4):
    model.eval()
    samples = test_dataset.shuffle().select(np.arange(rows*cols))
    fig = plt.figure(figsize=(cols*4,rows*4))
    for i in range(rows*cols):
        img = samples[i]['image']
        inputs = samples[i]['pixel_values'].unsqueeze(0)
        labels = samples[i]['classes']
        labels = ','.join([id2label[lb] for lb in labels])
        with torch.no_grad():
            logits = model(inputs)
        # apply sigmoid activation to convert logits to probabilities,
        # keeping labels above a confidence threshold of 0.5
        predictions = logits.sigmoid() > 0.5
        # convert the one-hot encoded predictions back to a list of labels
        predictions = predictions.float().numpy().flatten() # convert boolean predictions to float
        pred_labels = np.where(predictions==1)[0] # find indices where the prediction is 1
        pred_labels = ','.join([id2label[label] for label in pred_labels]) # convert integer labels to strings
        label = f"labels: {labels}\npredicted: {pred_labels}"
        fig.add_subplot(rows,cols,i+1)
        plt.imshow(img)
        plt.title(label)
        plt.axis('off')
show_predictions(rows=5,cols=5)